0XFF 伸展树简介
伸展树是一种二叉搜索树(二叉排序树),通过旋转(多次Splay操作)保持平衡,主要用于维护LCT(动态树,Link-Cut Tree)。LCT将在以后的学习笔记中详细讲解。
0X00 伸展树能干什么
最基础的:
- 插入节点k
- 删除节点k
- 求全局的k小值(根据排名求数)
- 求k的全局排名(根据数求排名)
- 求k的前驱(严格小于k的最大的数)
- 求k的后继(严格大于k的最小的数)
详见:离线模板、强制在线模板。(本人已经使用无旋fhq-Treap、带旋Treap、Splay三种方式AC离线版,下面将详细介绍Splay方法)
进阶版(加tag):
- 给定一个初始数列
- 每次操作将区间 [l, r] 翻转
详见:文艺平衡树
但其实进阶版比模板要短,大概是因为进阶版操作少,去掉了一系列不必要的操作。
0X01 普通平衡树
0X01-01 数据存储
const int MAXN = 100001;
int fa[MAXN], key[MAXN], son[MAXN][2], size[MAXN], cnt[MAXN];
int rt, sz;
fa数组存每个节点当前的父亲,key是节点权值,son的0和1分别表示左儿子和右儿子,size是以该节点为根的子树的大小,cnt是该数的数量。
rt存当前的根,sz是点的个数。
0X01-02 清空一棵子树
清空以x为根的子树,只需要把所有数据都清零即可。
inline void clear(int x) {fa[x] = key[x] = son[x][0] = son[x][1] = size[x] = cnt[x] = 0;}
0X01-03 获取节点与父亲的关系
这里用于获取节点x与父亲的关系,即是左儿子还是右儿子。
inline bool get(int x) {return son[fa[x]][1] == x;}
0X01-04 更新子树大小
这里是更新子树大小的操作,因为需要用到儿子的size,所以每次更新前请确保所有孩子的值都是最新的!
void update(int x)
{
if(!x) return;//数据不合法
size[x] = cnt[x];
if(son[x][0]) size[x] += size[son[x][0]];
if(son[x][1]) size[x] += size[son[x][1]];
}
0X01-05 连边
将x设为y的关系为z的儿子。
void connect(int x, int y, int z)
{
if(x) fa[x] = y;//把x的父亲设为y
if(y) son[y][z] = x;//把y的关系为z的儿子设为x
}
0X01-06 上旋
这个操作是将节点x上旋,详见注释。
void rotate(int x)
{
int f = fa[x], ff = fa[f], p = get(x), q = get(f);//获取父亲、爷爷、自己与父亲关系、父亲与爷爷关系
connect(son[x][p^1], f, p);//把自己的p^1儿子连边到父亲,关系为p
connect(f, x, p^1);//把父亲下旋到自己,关系为p^1
connect(x, ff, q);//把自己上旋到爷爷,关系为q
update(f);//为什么是这个顺序?因为这次操作把父亲旋转为自己的儿子,更新顺序就是原父亲到自己
update(x);
}
上旋操作不太好理解,这里放几张图上来。
下面这个是旋转前的树,边权标的是关系。

现在假设我们要上旋3号点,可以看到,3号点与父亲(2号点)的关系为0,异或为1,于是把3号点的右儿子(5号点)给父亲,此时因为父亲的儿子更新,3号点孤立:

把父亲连到自己,成为自己的右儿子,因为父亲的fa被更新,1号点孤立:

把自己连到爷爷,关系为0:

可以看到3号点被上旋。
0X01-07 splay操作维护平衡
在这里,每次把节点旋转到根即可,但是在文艺平衡树中要指定旋转终点。这次为了省事,直接写旋转到根的即可:
void splay(int x)
{
for(int f;f=fa[x];rotate(x)) if(fa[f]) rotate(get(x)==get(f)?f:x);
rt = x;
}
一直上旋点x直到上旋到跟,把root(rt)改为当前点。
splay操作维护平衡性,写平衡树记得有事没事splay一下,毕竟splay多了用不了多长时间,splay少了会出错。
0X01-08 插入节点
- 情况1:如果不存在根(即一个数都没有),直接在根插入。
- 情况2:按照排序树性质向下查找,发现有过该权值的点,更新大小即可,记得需要splay。
- 情况3:查找到空节点发现不存在该节点,在最后那里加入这个点,更新信息即可,也需要splay。
void insert(int x)
{
if(!rt)//case 1
{
rt = ++sz;
key[rt] = x;
cnt[rt] = size[rt] = 1;
son[rt][0] = son[rt][1] = 0;
return;
}
int u = rt, f = 0;
while(true)
{
if(key[u] == x)//case 2
{
++cnt[u];
update(u);
update(f);
splay(u);
return;
}
f = u, u = son[u][x>key[u]];//排序树性质
if(!u)//case 3
{
key[++sz] = x;
cnt[sz] = size[sz] = 1;
fa[sz] = f;
son[f][x>key[f]] = sz;
update(f);
splay(sz);
return;
}
}
}
0X01-09 查询数的排名
查询数的排名根据排序树性质判断是否在左子树,之后累加size即可。
int find(int x)
{
int u = rt, ans = 0;
while(true)
{
if(x < key[u])//在左子树,当前点走到左儿子
{
u = son[u][0];
continue;
}
ans += size[son[u][0]];//不在左子树,ans加上左子树大小
if(x == key[u])//找到该节点,splay维持平衡,返回答案
{
splay(u);
return ans + 1;
}
ans += cnt[u];//在右子树,ans加上当前位置数个数,向右查找
u = son[u][1];
}
}
0X01-10 查询排名对应的数
这个操作与上一个类似,每次把要查的排名减少,查找子树即可。
int kth(int x)
{
int u = rt;
while(true)
{
if(son[u][0] && x <= size[son[u][0]])
{
u = son[u][0];
continue;
}
if(son[u][0]) x -= size[son[u][0]];
if(x <= cnt[u])
{
splay(u);
return key[u];
}
x -= cnt[u];
u = son[u][1];
}
}
0X01-11 前驱
根据前驱定义与排序树性质,找到左儿子并一直向右走即可。
int pre()
{
int u = son[rt][0];
while(son[u][1]) u = son[u][1];
return u;
}
0X01-12 后继
后继同理。
int suc()
{
int u = son[rt][1];
while(son[u][0]) u = son[u][0];
return u;
}
0X01-13 删除节点
- 情况1:当前值有多个数,cnt更新即可。
- 情况2:没有左右儿子,直接删除。
- 情况3:只有右儿子,右儿子取代当前节点,把当前节点删除。
- 情况4:只有左儿子同理。
- 情况5:左右儿子都有,把前驱转到根节点,右儿子连到前驱成为前驱的右儿子,左儿子在旋转过程中旋转到其他地方,直接清除即可。
void del(int x)
{
int useless = find(x);//find把当前点转到根,方便操作
if(cnt[rt] > 1)//case 1
{
--cnt[rt];
update(rt);
return;
}
if(!son[rt][0] && !son[rt][1])//case 2
{
clear(rt);
rt = 0;
return;
}
if(!son[rt][0])//case 3
{
int tmp = rt;
fa[rt=son[rt][1]] = 0;
clear(tmp);
return;
}
if(!son[rt][1])//case 4
{
int tmp = rt;
fa[rt=son[rt][0]] = 0;
clear(tmp);
return;
}
int tmp = rt, p = pre();//case 5
splay(p);
connect(son[tmp][1], rt, 1);
clear(tmp);
update(rt);
}
0X01-14 完整程序
//By: Luogu@rui_er(122461)
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100001;
int fa[MAXN], key[MAXN], son[MAXN][2], size[MAXN], cnt[MAXN];
int rt, sz;
inline int read()
{
char c;
int s = 1, w = 0;
c = getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') s = -1;
for(;isdigit(c);c=getchar()) w = (w << 3) + (w << 1) + (c ^ 48);
return s * w;
}
inline void clear(int x) {fa[x] = key[x] = son[x][0] = son[x][1] = size[x] = cnt[x] = 0;}
inline bool get(int x) {return son[fa[x]][1] == x;}
void update(int x)
{
if(!x) return;
size[x] = cnt[x];
if(son[x][0]) size[x] += size[son[x][0]];
if(son[x][1]) size[x] += size[son[x][1]];
}
void connect(int x, int y, int z)
{
if(x) fa[x] = y;
if(y) son[y][z] = x;
}
void rotate(int x)
{
int f = fa[x], ff = fa[f], p = get(x), q = get(f);
connect(son[x][p^1], f, p);
connect(f, x, p^1);
connect(x, ff, q);
update(f);
update(x);
}
void splay(int x)
{
for(int f;f=fa[x];rotate(x)) if(fa[f]) rotate(get(x)==get(f)?f:x);
rt = x;
}
void insert(int x)
{
if(!rt)
{
rt = ++sz;
key[rt] = x;
cnt[rt] = size[rt] = 1;
son[rt][0] = son[rt][1] = 0;
return;
}
int u = rt, f = 0;
while(true)
{
if(key[u] == x)
{
++cnt[u];
update(u);
update(f);
splay(u);
return;
}
f = u, u = son[u][x>key[u]];
if(!u)
{
key[++sz] = x;
cnt[sz] = size[sz] = 1;
fa[sz] = f;
son[f][x>key[f]] = sz;
update(f);
splay(sz);
return;
}
}
}
int find(int x)
{
int u = rt, ans = 0;
while(true)
{
if(x < key[u])
{
u = son[u][0];
continue;
}
ans += size[son[u][0]];
if(x == key[u])
{
splay(u);
return ans + 1;
}
ans += cnt[u];
u = son[u][1];
}
}
int kth(int x)
{
int u = rt;
while(true)
{
if(son[u][0] && x <= size[son[u][0]])
{
u = son[u][0];
continue;
}
if(son[u][0]) x -= size[son[u][0]];
if(x <= cnt[u])
{
splay(u);
return key[u];
}
x -= cnt[u];
u = son[u][1];
}
}
int pre()
{
int u = son[rt][0];
while(son[u][1]) u = son[u][1];
return u;
}
int suc()
{
int u = son[rt][1];
while(son[u][0]) u = son[u][0];
return u;
}
void del(int x)
{
int useless = find(x);
if(cnt[rt] > 1)
{
--cnt[rt];
update(rt);
return;
}
if(!son[rt][0] && !son[rt][1])
{
clear(rt);
rt = 0;
return;
}
if(!son[rt][0])
{
int tmp = rt;
fa[rt=son[rt][1]] = 0;
clear(tmp);
return;
}
if(!son[rt][1])
{
int tmp = rt;
fa[rt=son[rt][0]] = 0;
clear(tmp);
return;
}
int tmp = rt, p = pre();
splay(p);
connect(son[tmp][1], rt, 1);
clear(tmp);
update(rt);
}
int main()
{
int n = read();
while(n--)
{
int opt = read(), x = read();
if(opt == 1) insert(x);
else if(opt == 2) del(x);
else if(opt == 3) printf("%d\n", find(x));
else if(opt == 4) printf("%d\n", kth(x));
else if(opt == 5)
{
insert(x);
printf("%d\n", key[pre()]);
del(x);
}
else
{
insert(x);
printf("%d\n", key[suc()]);
del(x);
}
}
return 0;
}
0X02 文艺平衡树
文艺平衡树只需要维护一个 tag 即可。
//By: Luogu@rui_er(122461)
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
int n, m;
namespace Splay
{
int fa[MAXN] = {0}, key[MAXN] = {0}, son[MAXN][2] = {{0, 0}}, size[MAXN] = {0};
int que[MAXN] = {0}, tag[MAXN] = {0};
int rt, sz;
inline int read()
{
char c;
int s = 1, w = 0;
c = getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') s = -1;
for(;isdigit(c);c=getchar()) w = (w << 3) + (w << 1) + (c ^ 48);
return s * w;
}
inline void clear(int x) {fa[x] = key[x] = son[x][0] = son[x][1] = size[x] = 0;}
inline bool get(int x) {return son[fa[x]][1] == x;}
void update(int x)
{
if(!x) return;
size[x] = 1;
if(son[x][0]) size[x] += size[son[x][0]];
if(son[x][1]) size[x] += size[son[x][1]];
}
void pushdown(int x)
{
if(tag[x])
{
tag[x] = 0;
tag[son[x][0]] ^= 1;
tag[son[x][1]] ^= 1;
swap(son[x][0], son[x][1]);
}
}
void connect(int x, int y, int z) {fa[x] = y, son[y][z] = x;}
void rotate(int x)
{
int f = fa[x], ff = fa[f], p = get(x), q = get(f);
connect(x, ff, q);
connect(son[x][p^1], f, p);
connect(f, x, p^1);
update(f);
update(x);
}
void splay(int x, int y)
{
// cout<<"splay(): begin"<<endl;
int len = 0;
for(int i=x;i;i=fa[i]) que[++len] = i;
// cout<<"splay(): get queue"<<endl;
for(int i=len;i;i--) pushdown(que[i]);
// cout<<"splay(): after pushdown()"<<endl;
while(fa[x] != y)
{
int f = fa[x];
if(fa[f] != y) rotate(get(f)==get(x)?f:x);
rotate(x);
// cout<<"splay(): rotating---f: "<<f<<", x: "<<x<<", y: "<<y<<", fa[f]: "<<fa[f]
// <<endl;
}
// cout<<"splay(): after rotated"<<endl;
if(!y) rt = x;
// cout<<"splay(): end"<<endl;
}
void insert(int x)
{
int u = rt, f = 0;
while(u) f = u, u = son[u][x>key[u]];
u = ++sz;
key[u] = x;
fa[u] = f;
if(f) son[f][x>key[f]] = u;
size[u] = 1;
son[u][0] = son[u][1] = 0;
splay(u, 0);
}
int kth(int x)
{
int u = rt;
while(true)
{
pushdown(u);
if(size[son[u][0]] >= x) u = son[u][0];
else
{
x -= size[son[u][0]];
if(x == 1) return u;
--x;
u = son[u][1];
}
}
}
void reverse(int l, int r)
{
// cout<<"reverse(): begin"<<endl;
l = kth(l), r = kth(r+2);
// cout<<"reverse(): after kth()"<<endl;
splay(l, 0);
// cout<<"reverse(): after splay(l, 0)"<<endl;
splay(r, l);
// cout<<"reverse(): after splay(r, l)"<<endl;
tag[son[r][0]] ^= 1;
// cout<<"reverse(): end"<<endl;
}
void print(int x)
{
// cout<<"print(): begin"<<endl;
pushdown(x);
if(son[x][0]) print(son[x][0]);
if(key[x] > 1 && key[x] < n + 2) cout<<key[x]-1<<" ";
if(son[x][1]) print(son[x][1]);
// cout<<"print(): end"<<endl;
}
}
using namespace Splay;
int main()
{
n = read(), m = read();
for(int i=1;i<=n+2;i++) insert(i);
// for(int i=0;i<=n;i++) cout<<fa[i]<<" "; cout<<endl;
for(int i=1;i<=m;i++)
{
int l = read(), r = read();
reverse(l, r);
}
print(rt);
return 0;
}